#! /usr/bin/env python

############################################################################
# Copyright (c) 2009 Dr. Peter Bunting, Aberystwyth University
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#
# Purpose:  A series of classes to generate image plots from the ptxt text 
#           file format exported for plotting from the rsgis library. 
#
# Author: Pete Bunting
# Email: pete.bunting@aber.ac.uk
# Date: 11/08/2009
# Version: 1.0
#
# History:
# Version 1.0 - Created.
#
#############################################################################

import os.path
import sys, re
import numpy as np
from pylab import *
try:
    import numpy as np
except ImportError:
    print('''Numpy is required for this script, not found.
Check: http://numpy.scipy.org/''')
    exit()
try:
    import matplotlib.pyplot as plt
except ImportError:
    print('''MatPlotLib is required for this script, not found.
Check: http://matplotlib.sourceforge.net/''')
    exit()
try:
    from enthought.mayavi import mlab
    importMlab = True
except ImportError:
    importMlab = False

def checkforHash(line):
    """Report whether ``line`` contains a '#' character anywhere.

    The ptxt readers in this file call this to decide whether a line is a
    header/comment line that must not be parsed as data.

    Args:
        line: The text to inspect.

    Returns:
        True if at least one '#' occurs in ``line``, otherwise False.
    """
    # The membership test stops at the first match instead of walking the
    # whole line by index.
    return '#' in line

def stringTokenizer(line, delimiter):
    """Split ``line`` into tokens at every occurrence of ``delimiter``.

    Empty tokens are kept, so consecutive, leading or trailing delimiters
    produce empty strings, and an empty ``line`` yields ``['']``. Tokens are
    not stripped; callers strip them as needed.

    Args:
        line: The text to split.
        delimiter: The single character that separates tokens.

    Returns:
        A list of the tokens in order of appearance (never empty).
    """
    # The character-by-character scan this replaces could only ever match a
    # one-character delimiter; anything else left the line as a single token.
    if len(delimiter) != 1:
        return [line]
    return line.split(delimiter)

class PlotFreq2D(object):
    """Histogram plot of the single-column values stored in a ptxt file."""

    def plot(self, data, outputFile, outFormat):
        """Draw a 30-bin histogram of ``data`` and save it to ``outputFile``.

        Args:
            data: Sequence of numeric values to bin.
            outputFile: Path the figure is written to.
            outFormat: Format name handed to ``savefig``.
        """
        fig = plt.figure(dpi=80)
        try:
            plt.hist(data, bins=30)
            plt.savefig(outputFile, format=outFormat)
        finally:
            # Release the figure so it does not stay registered with pyplot.
            plt.close(fig)

    def readPTXTFile(self, inputFile, data):
        """Append every value found in ``inputFile`` to ``data``.

        Lines containing a '#' are treated as comments and skipped, as are
        blank lines; every other line must hold a single number.

        Args:
            inputFile: Path of the ptxt file to read.
            data: List that receives the parsed floats (modified in place).

        Raises:
            ValueError: If a data line cannot be converted to float.
        """
        # The context manager closes the file even when a line fails to parse.
        with open(inputFile, 'r') as dataFile:
            for eachLine in dataFile:
                value = eachLine.strip()
                # Blank lines (e.g. a trailing newline) carry no data.
                if not value or checkforHash(eachLine):
                    continue
                data.append(float(value))

    def help(self):
        """Print the general usage text followed by the Freq2D parameters."""
        RSGISLibPlot().help()
        print('Additional Parameters for Freq2D.')
        print('\t<FORMAT> - output file format of the plot (PNG | PDF)')

    def createPlot(self):
        """Read the command-line arguments and produce the histogram.

        Expects ``sys.argv`` to be ``[script, input, output, format]``; any
        other argument count prints the help text instead. Errors raised
        while reading or plotting are reported on stdout, matching the other
        plot classes in this file.
        """
        print("PlotFreq2D - Basic Implementation..")
        if len(sys.argv) != 4:
            self.help()
            return

        inputFile = sys.argv[1].strip()
        outputFile = sys.argv[2].strip()
        outFormat = sys.argv[3].strip()

        data = list()
        try:
            self.readPTXTFile(inputFile, data)
            self.plot(data, outputFile, outFormat)
        except Exception as e:
            print("Error: ", e)
        
class PlotFreq3D(object):
    """Placeholder plot class: createPlot only prints a not-implemented notice."""

    def createPlot(self):
        """Tell the user on stdout that this plot type is not available."""
        notice = "PlotFreq3D - NOT IMPLEMENTED YET!!"
        print(notice)

class PlotLines2D(object):
    """Placeholder plot class: createPlot only prints a not-implemented notice."""

    def createPlot(self):
        """Tell the user on stdout that this plot type is not available."""
        notice = "PlotLines2D - NOT IMPLEMENTED YET!!"
        print(notice)

class PlotLines3D(object):
    """Placeholder plot class: createPlot only prints a not-implemented notice."""

    def createPlot(self):
        """Tell the user on stdout that this plot type is not available."""
        notice = "PlotLines3D - NOT IMPLEMENTED YET!!"
        print(notice)

class PlotDensity(object):
    """Placeholder plot class: createPlot only prints a not-implemented notice."""

    def createPlot(self):
        """Tell the user on stdout that this plot type is not available."""
        notice = "PlotDensity - NOT IMPLEMENTED YET!!"
        print(notice)

class Plot2DScatter(object):
    """Scatter plot of the ``x,y`` pairs stored in a ptxt file."""

    def plot(self, dataX, dataY, outputFile, outFormat):
        """Draw the points as circles and save the figure.

        Args:
            dataX: Sequence of x coordinates.
            dataY: Sequence of y coordinates, parallel to ``dataX``.
            outputFile: Path the figure is written to.
            outFormat: Format name handed to ``savefig``.
        """
        fig = plt.figure(dpi=80)
        try:
            plt.scatter(dataX, dataY, marker='o')
            plt.savefig(outputFile, format=outFormat)
        finally:
            # Release the figure so it does not stay registered with pyplot.
            plt.close(fig)

    def readPTXTFile(self, inputFile, dataX, dataY):
        """Append the comma-separated pairs in ``inputFile`` to the lists.

        Lines containing a '#' are treated as comments and skipped, as are
        blank lines.

        Args:
            inputFile: Path of the ptxt file to read.
            dataX: List receiving the first column (modified in place).
            dataY: List receiving the second column (modified in place).

        Raises:
            Exception: If a data line does not hold exactly two tokens.
            ValueError: If a token cannot be converted to float.
        """
        # The context manager closes the file even when a line is rejected.
        with open(inputFile, 'r') as dataFile:
            for eachLine in dataFile:
                # Blank lines (e.g. a trailing newline) carry no data.
                if not eachLine.strip() or checkforHash(eachLine):
                    continue
                tokens = stringTokenizer(eachLine, ',')
                if len(tokens) != 2:
                    raise Exception('Not enough tokens in a line')
                dataX.append(float(tokens[0].strip()))
                dataY.append(float(tokens[1].strip()))

    def help(self):
        """Print the general usage text followed by the 2DScatter parameters."""
        RSGISLibPlot().help()
        # Previously labelled "Freq2D", copied from the histogram class.
        print('Additional Parameters for 2DScatter.')
        print('\t<FORMAT> - output file format of the plot (PNG | PDF)')

    def createPlot(self):
        """Read the command-line arguments and produce the scatter plot.

        Expects ``sys.argv`` to be ``[script, input, output, format]``; any
        other argument count prints the help text instead. Errors raised
        while reading or plotting are reported on stdout.
        """
        print("Plot2DScatter - Basic Implementation..")
        if len(sys.argv) != 4:
            self.help()
            return

        inputFile = sys.argv[1].strip()
        outputFile = sys.argv[2].strip()
        outFormat = sys.argv[3].strip()

        dataX = list()
        dataY = list()
        try:
            self.readPTXTFile(inputFile, dataX, dataY)
            self.plot(dataX, dataY, outputFile, outFormat)
        except Exception as e:
            print("Error: ", e)
        
class PlotClass2DScatter(object):
    """Scatter plot of ``class,x,y`` records, one colour per class."""

    def plot(self, dataX, dataY, classNames, classNamesLowLim, classNamesUpLim, outputFile, outFormat):
        """Draw one scatter series per class, add a legend and save.

        ``dataX``/``dataY`` hold the points grouped by class; class ``i``
        occupies the half-open index range
        ``[classNamesLowLim[i], classNamesUpLim[i])``.

        Args:
            dataX: x coordinates grouped by class.
            dataY: y coordinates grouped by class.
            classNames: Legend label for each class.
            classNamesLowLim: Start index of each class in the data lists.
            classNamesUpLim: End index (exclusive) of each class.
            outputFile: Path the figure is written to.
            outFormat: Format name handed to ``savefig``.
        """
        fig = plt.figure(dpi=80)
        try:
            for i, className in enumerate(classNames):
                classCol = self.stdCol(i)
                lowLim = int(classNamesLowLim[i])
                upLim = int(classNamesUpLim[i])
                plt.scatter(dataX[lowLim:upLim], dataY[lowLim:upLim], marker='o',
                            edgecolor=classCol, facecolor=classCol, label=className)

            plt.legend()
            plt.savefig(outputFile, format=outFormat)
        finally:
            # Release the figure so it does not stay registered with pyplot.
            plt.close(fig)

    def readPTXTFile(self, inputFile, dataRawClass, dataRawX, dataRawY, classNames):
        """Read the class names and the ``class,x,y`` records of a ptxt file.

        The second line of the file lists the class names, comma separated,
        with any '#' characters removed. Every other line that contains a
        '#' is skipped as a comment, as are blank lines.

        Args:
            inputFile: Path of the ptxt file to read.
            dataRawClass: List receiving each record's class name.
            dataRawX: List receiving each record's x value.
            dataRawY: List receiving each record's y value.
            classNames: List receiving the class names from line two.

        Raises:
            Exception: If a data line does not hold exactly three tokens.
            ValueError: If an x or y token cannot be converted to float.
        """
        # The context manager closes the file even when a line is rejected.
        with open(inputFile, 'r') as dataFile:
            for lineNo, eachLine in enumerate(dataFile):
                if lineNo == 1:  # Class names stored in second line of file
                    names = stringTokenizer(eachLine.replace('#', ''), ',')
                    classNames.extend(name.strip() for name in names)
                    continue
                # Blank lines (e.g. a trailing newline) carry no data.
                if not eachLine.strip() or checkforHash(eachLine):
                    continue
                tokens = stringTokenizer(eachLine, ',')
                if len(tokens) != 3:
                    raise Exception('Not enough tokens in a line')
                dataRawClass.append(tokens[0].strip())
                dataRawX.append(float(tokens[1].strip()))
                dataRawY.append(float(tokens[2].strip()))

    def sortData(self, dataRawClass, dataRawX, dataRawY, dataX, dataY, classNames, classNamesLowLim, classNamesUpLim):
        """Group the raw records by class, in the order of ``classNames``.

        For each class name the matching points are appended to
        ``dataX``/``dataY`` and the half-open index range they occupy is
        appended to ``classNamesLowLim``/``classNamesUpLim``. Records whose
        class is not listed in ``classNames`` are left out.

        Args:
            dataRawClass: Class name of each raw record.
            dataRawX: x value of each raw record.
            dataRawY: y value of each raw record.
            dataX: Output list of grouped x values (modified in place).
            dataY: Output list of grouped y values (modified in place).
            classNames: Class names, defining the group order.
            classNamesLowLim: Output list of group start indices.
            classNamesUpLim: Output list of group end indices (exclusive).
        """
        totalData = 0
        for currentClass in classNames:
            numDataClass = 0
            for elementClassName, x, y in zip(dataRawClass, dataRawX, dataRawY):
                if elementClassName == currentClass:
                    dataX.append(x)
                    dataY.append(y)
                    numDataClass += 1

            classNamesLowLim.append(totalData)
            totalData += numDataClass
            classNamesUpLim.append(totalData)

    def stdCol(self, i):
        """Return the plot colour for class index ``i``.

        The first six classes use the single-letter colours r, g, b, c, m
        and y. Later classes take a colour from the "jet" colormap.

        Args:
            i: Zero-based class index.

        Returns:
            A colour letter, or an RGBA tuple from the colormap.
        """
        sdColScheme = ("r", "g", "b", "c", "m", "y")
        if i < len(sdColScheme):
            return sdColScheme[i]
        jet = plt.get_cmap("jet")
        # Integer lookups must stay inside [0, jet.N); the raw i*100 is out of
        # range for every i here, which gave all of these classes one colour.
        return jet((i * 100) % jet.N)

    def help(self):
        """Print the general usage text followed by the Class2DScatter parameters."""
        RSGISLibPlot().help()
        print('Additional Parameters for Class2DScatter.')
        print('\t<FORMAT> - output file format of the plot (PNG | PDF)')

    def createPlot(self):
        """Read the command-line arguments and produce the classed scatter plot.

        Expects ``sys.argv`` to be ``[script, input, output, format]``; any
        other argument count prints the help text instead. Errors raised
        while reading, sorting or plotting are reported on stdout.
        """
        print("Class2DScatter ")
        if len(sys.argv) != 4:
            self.help()
            return

        inputFile = sys.argv[1].strip()
        outputFile = sys.argv[2].strip()
        outFormat = sys.argv[3].strip()

        classNames = list()
        dataRawClass = list()
        dataRawX = list()
        dataRawY = list()

        dataX = list()
        dataY = list()
        classNamesLowLim = list()
        classNamesUpLim = list()

        try:
            self.readPTXTFile(inputFile, dataRawClass, dataRawX, dataRawY, classNames)
            self.sortData(dataRawClass, dataRawX, dataRawY, dataX, dataY, classNames, classNamesLowLim, classNamesUpLim)
            # The raw copies are no longer needed once the data is grouped.
            del dataRawClass
            del dataRawX
            del dataRawY
            self.plot(dataX, dataY, classNames, classNamesLowLim, classNamesUpLim, outputFile, outFormat)
        except Exception as e:
            print("Error: ", e)

class Plot3DScatter(object):
    """Two side-by-side scatter panels (X against Z, Y against Z) of x,y,z points."""

    def plot(self, dataX, dataY, dataZ, outputFile, outFormat):
        """Plot height against X and against Y in two panels and save.

        X and Y are shifted so that their minimum becomes zero, and both
        panels share the same axis limits: the horizontal range covers the
        larger of the two shifted extents, the vertical range covers
        ``dataZ``, each padded by 2.

        Args:
            dataX: Sequence of x coordinates.
            dataY: Sequence of y coordinates.
            dataZ: Sequence of z (height) values.
            outputFile: Path the figure is written to.
            outFormat: Format name handed to ``savefig``.

        Raises:
            ValueError: If any of the sequences is empty (from min/max).
        """
        dataXMin = min(dataX)
        dataXBaseZero = [item - dataXMin for item in dataX]

        dataYMin = min(dataY)
        dataYBaseZero = [item - dataYMin for item in dataY]

        ptsMaxXAxis = max(max(dataXBaseZero), max(dataYBaseZero))
        axisdimensionsPts = [-2, ptsMaxXAxis + 2, min(dataZ) - 2, max(dataZ) + 2]

        # dpi has to be a number; it was previously passed as the string "200".
        fig = plt.figure(dpi=200)
        try:
            plt.subplot(121)
            plt.ylabel("Height")
            plt.xlabel("X")
            plt.scatter(dataXBaseZero, dataZ, color='black', marker='o')
            plt.axis(axisdimensionsPts)
            plt.grid()

            plt.subplot(122)
            plt.xlabel("Y")
            plt.scatter(dataYBaseZero, dataZ, color='black', marker='o')
            plt.axis(axisdimensionsPts)
            plt.grid()
            plt.savefig(outputFile, format=outFormat)
        finally:
            # Release the figure so it does not stay registered with pyplot.
            plt.close(fig)

    def readPTXTFile(self, inputFile, dataX, dataY, dataZ):
        """Append the comma-separated ``x,y,z`` triples in ``inputFile`` to the lists.

        Lines containing a '#' are treated as comments and skipped, as are
        blank lines.

        Args:
            inputFile: Path of the ptxt file to read.
            dataX: List receiving the first column (modified in place).
            dataY: List receiving the second column (modified in place).
            dataZ: List receiving the third column (modified in place).

        Raises:
            Exception: If a data line does not hold exactly three tokens.
            ValueError: If a token cannot be converted to float.
        """
        # The context manager closes the file even when a line is rejected.
        with open(inputFile, 'r') as dataFile:
            for eachLine in dataFile:
                # Blank lines (e.g. a trailing newline) carry no data.
                if not eachLine.strip() or checkforHash(eachLine):
                    continue
                tokens = stringTokenizer(eachLine, ',')
                if len(tokens) != 3:
                    raise Exception('Not enough tokens in a line')
                dataX.append(float(tokens[0].strip()))
                dataY.append(float(tokens[1].strip()))
                dataZ.append(float(tokens[2].strip()))

    def help(self):
        """Print the general usage text followed by the scatter3D parameters."""
        RSGISLibPlot().help()
        print('Additional Parameters for scatter3D.')
        print('\t<FORMAT> - output file format of the plot (PNG | PDF)')

    def createPlot(self):
        """Read the command-line arguments and produce the two-panel plot.

        Expects ``sys.argv`` to be ``[script, input, output, format]``; any
        other argument count prints the help text instead. Errors raised
        while reading or plotting are reported on stdout.
        """
        print("Plot3DScatter - Basic Implementation..")
        if len(sys.argv) != 4:
            self.help()
            return

        inputFile = sys.argv[1].strip()
        outputFile = sys.argv[2].strip()
        outFormat = sys.argv[3].strip()

        dataX = list()
        dataY = list()
        dataZ = list()
        try:
            self.readPTXTFile(inputFile, dataX, dataY, dataZ)
            self.plot(dataX, dataY, dataZ, outputFile, outFormat)
        except Exception as e:
            print("Error: ", e)

class Plotc2DScatter(object):
    """Placeholder plot class: createPlot only prints a not-implemented notice."""

    def createPlot(self):
        """Tell the user on stdout that this plot type is not available."""
        notice = "Plotc2DScatter - NOT IMPLEMENTED YET!!"
        print(notice)

class Plotc3DScatter(object):
    """Placeholder plot class: createPlot only prints a not-implemented notice."""

    def createPlot(self):
        """Tell the user on stdout that this plot type is not available."""
        notice = "Plotc3DScatter - NOT IMPLEMENTED YET!!"
        print(notice)

class PlotSurface(object):
    """Surface plot, rendered with mayavi's mlab, of gridded ``x,y,z`` samples."""

    def plot(self, dataX, dataY, dataZ, outputFile):
        """Build a height grid from the samples and render it as a surface.

        The grid has ``max(dataY)`` rows and ``max(dataX)`` columns, initially
        all zero; each sample writes its z value at row ``y - 1``, column
        ``x - 1`` (coordinates are 1-based). The figure is saved to
        ``outputFile`` and then shown in the mlab viewer.

        Args:
            dataX: Integer column coordinates, starting at 1.
            dataY: Integer row coordinates, starting at 1.
            dataZ: Height value of each sample.
            outputFile: Path the rendered figure is written to.

        Raises:
            ValueError: If the input is empty or a coordinate is below 1.
        """
        maxX = int(max(dataX))
        maxY = int(max(dataY))
        # A coordinate below 1 would turn into a negative index and silently
        # overwrite a cell at the far edge of the grid.
        if min(dataX) < 1 or min(dataY) < 1:
            raise ValueError('Surface coordinates must be integers starting at 1')

        surface = [[0] * maxX for _ in range(maxY)]
        for x, y, z in zip(dataX, dataY, dataZ):
            surface[y - 1][x - 1] = z

        fig = mlab.figure()
        mlab.surf(surface, figure=fig)
        mlab.savefig(outputFile, figure=fig)
        mlab.show()

    def readPTXTFile(self, inputFile, dataX, dataY, dataZ):
        """Append the comma-separated ``x,y,z`` samples in ``inputFile`` to the lists.

        x and y are parsed as integers, z as a float. Lines containing a '#'
        are treated as comments and skipped, as are blank lines.

        Args:
            inputFile: Path of the ptxt file to read.
            dataX: List receiving the first column (modified in place).
            dataY: List receiving the second column (modified in place).
            dataZ: List receiving the third column (modified in place).

        Raises:
            Exception: If a data line does not hold exactly three tokens.
            ValueError: If a token cannot be converted to a number.
        """
        # The context manager closes the file even when a line is rejected.
        with open(inputFile, 'r') as dataFile:
            for eachLine in dataFile:
                # Blank lines (e.g. a trailing newline) carry no data.
                if not eachLine.strip() or checkforHash(eachLine):
                    continue
                tokens = stringTokenizer(eachLine, ',')
                if len(tokens) != 3:
                    raise Exception('Not enough tokens in a line')
                dataX.append(int(tokens[0].strip()))
                dataY.append(int(tokens[1].strip()))
                dataZ.append(float(tokens[2].strip()))

    def help(self):
        """Print the general usage text (this plot has no extra parameters)."""
        RSGISLibPlot().help()

    def createPlot(self):
        """Read the command-line arguments and render the surface.

        Exits with status 1 when mayavi could not be imported. Only the input
        and output paths are used, so ``sys.argv`` may be
        ``[script, input, output]``; a fourth argument is still accepted for
        backward compatibility and ignored. Any other argument count prints
        the help text. Errors raised while reading or plotting are reported
        on stdout.
        """
        print("PlotSurface - Basic Implementation..")
        if not importMlab:
            print('mayavi, from Enthough Python distribution is required for surface plots, not found')
            # sys.exit works without the site module, and the non-zero status
            # tells callers that nothing was plotted.
            sys.exit(1)

        if len(sys.argv) not in (3, 4):
            self.help()
            return

        inputFile = sys.argv[1].strip()
        outputFile = sys.argv[2].strip()

        dataX = list()
        dataY = list()
        dataZ = list()
        try:
            print("Reading Data ")
            self.readPTXTFile(inputFile, dataX, dataY, dataZ)
            self.plot(dataX, dataY, dataZ, outputFile)
        except Exception as e:
            print("Error: ", e)

class PlotcSurface(object):
    """Placeholder plot class: createPlot only prints a not-implemented notice."""

    def createPlot(self):
        """Tell the user on stdout that this plot type is not available."""
        notice = "PlotcSurface - NOT IMPLEMENTED YET!!"
        print(notice)


class RSGISLibPlot(object):
    """Command-line entry point: picks the plot class named in a ptxt file."""

    def runPlot(self, inputFile):
        """Read the plot type from the first line of ``inputFile`` and run it.

        The first line, stripped of surrounding whitespace, is matched
        against the known "#<type>" tags. A match instantiates the
        corresponding plot class and calls its ``createPlot``; anything else
        prints an error naming the unrecognised type.

        Args:
            inputFile: Path of the ptxt file to plot.
        """
        # readline returns '' for an empty file, which falls through to the
        # unrecognised-type message below.
        with open(inputFile, 'r') as inputPTXTFile:
            plottype = inputPTXTFile.readline().strip()

        plotClasses = {
            "#freq2D": PlotFreq2D,
            "#freq3D": PlotFreq3D,
            "#lines2D": PlotLines2D,
            "#lines3D": PlotLines3D,
            "#Density": PlotDensity,
            "#2DScatter": Plot2DScatter,
            "#ClassScatter2D": PlotClass2DScatter,
            "#3DScatter": Plot3DScatter,
            "#c2DScatter": Plotc2DScatter,
            "#c3DScatter": Plotc3DScatter,
            "#Surface": PlotSurface,
            "#cSurface": PlotcSurface,
        }

        plotClass = plotClasses.get(plottype)
        if plotClass is None:
            print("ERROR: Did not recognise the plot type '" + plottype + "'.")
            return

        # freq2D has never announced itself (its message was commented out),
        # so it stays silent to keep the script's output unchanged.
        if plottype != "#freq2D":
            print("Plot " + plottype[1:])
        plotClass().createPlot()

    def run(self):
        """Validate the command line and start the plot.

        At least an input and an output argument must follow the script
        name; otherwise the help text is printed. A message is printed when
        the input file does not exist.
        """
        if len(sys.argv) <= 2:
            self.help()
            return

        inputFile = sys.argv[1].strip()
        if os.path.exists(inputFile):
            self.runPlot(inputFile)
        else:
            print('Input File does not exist')

    def help(self):
        """Print the general usage text of the script.

        The @...@ tokens are printed literally by this code; presumably they
        are substituted when the script is installed (not visible here).
        """
        print('rsgislibplot.py script generates plots from ptxt files generated from the ')
        print('@PACKAGE@ software library')
        print('Usage: python rsgislibplot.py <INPUT> <OUTPUT> [PLOT SPECIFIC]')
        print('\t<INPUT> - input ptxt file')
        print('\t<OUTPUT> - output plot')
        print('\nThis script was distributed with @RSGISLIB_PACKAGE_STRING@')
        print('Copyright (C) @RSGISLIB_COPYRIGHT_YEAR@ Peter Bunting and Daniel Clewley')
        print('For support please email @RSGISLIB_PACKAGE_BUGREPORT@')

# Script entry point: build the dispatcher and let it process sys.argv.
if __name__ == '__main__':
    RSGISLibPlot().run()
